
Include ROCM support for CUDA extensions #12

Merged
amd-sriram merged 12 commits into main from rocm_rnnt_loss_feature
Feb 23, 2026

Conversation


@amd-sriram amd-sriram commented Feb 21, 2026

Motivation

Port the CUDA extensions to ROCm:

  • RNNTLoss
  • lfilter
  • forced align
  • CU CTC (CUDA CTC decoder)
  • overdrive
  • rir (removed)

Technical Details

Hipification process

tools/setup_helpers/extension.py returns the list of extensions to be used in the setup() call.
Hipification of the CUDA sources is performed by passing the -DHIPBLAS_V2 flag to the C++ and HIPCC compilers (via the cxx and nvcc entries of extra_compile_args) of torch's CppExtension or CudaExtension. This approach is also used in rocm/apex: https://github.com/ROCm/apex/blob/release/1.9.0/setup.py#L231.

extra_compile_args["cxx"].append("-DHIPBLAS_V2")
extra_compile_args["nvcc"] = ["-O3", "-DHIPBLAS_V2"]

In addition, the CUDA source files are also compiled when the _USE_ROCM flag is set, e.g.

if _USE_CUDA or _USE_ROCM:
    sources.append("iir_cuda.cu")
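The two snippets above can be combined into a single sketch of the build-configuration logic. This is a minimal illustration, not the PR's exact code; the function name and parameters are assumptions, while the flag values and source name are quoted from the PR.

```python
def rocm_compile_config(use_cuda: bool, use_rocm: bool):
    """Sketch of the extension.py changes: collect compiler flags and
    the CUDA source list for a CUDA or ROCm build."""
    extra_compile_args = {"cxx": []}
    sources = []
    if use_rocm:
        # hipcc receives the flags passed under the "nvcc" key
        extra_compile_args["cxx"].append("-DHIPBLAS_V2")
        extra_compile_args["nvcc"] = ["-O3", "-DHIPBLAS_V2"]
    # .cu files are compiled for both backends; torch's build helpers
    # hipify them when targeting ROCm
    if use_cuda or use_rocm:
        sources.append("iir_cuda.cu")
    return extra_compile_args, sources
```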

Fixing porting issues

The following changes were made to fix these porting errors:

1. TORCH_HIP_VERSION is not defined

/skishore/github/audio/src/libtorchaudio/utils_hip.cpp:20:10: error: ‘TORCH_HIP_VERSION’ was not declared in this scope; did you mean ‘TORCH_ABI_VERSION’? 

TORCH_HIP_VERSION is defined in tools/setup_helpers/extension.py, similar to https://github.com/ROCm/pytorch/blob/develop/cmake/public/LoadHIP.cmake#L166: math(EXPR TORCH_HIP_VERSION "(${HIP_VERSION_MAJOR} * 100) + ${HIP_VERSION_MINOR}")
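The same (major * 100) + minor computation can be sketched in Python, mirroring the CMake math(EXPR ...) line; the function name and the version-string parsing are assumptions for illustration, not the PR's exact code.

```python
def torch_hip_version(hip_version: str) -> int:
    """Compute TORCH_HIP_VERSION as (major * 100) + minor,
    mirroring LoadHIP.cmake's math(EXPR ...) expression."""
    major, minor = (int(part) for part in hip_version.split(".")[:2])
    return major * 100 + minor

# The resulting value would then be passed as a define, e.g.
# extra_compile_args["nvcc"].append(f"-DTORCH_HIP_VERSION={torch_hip_version('6.2.41133')}")
```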

2. hip namespace not defined

/skishore/github/audio/src/libtorchaudio/rnnt/gpu/compute.hip:97:36: error: no member named 'hip' in namespace 'libtorchaudio' 
         98 |   options.stream_ = libtorchaudio::hip::getCurrentHIPStreamMasqueradingAsCUDA(logits.get_device_index());
            |                     ~~~~~~~~~~~~~~~^

The function is defined in src/libtorchaudio/cuda_utils.h. The namespace libtorchaudio::cuda is not hipified in that header, but it does get hipified at the call site:

options.stream_ = libtorchaudio::cuda::getCurrentCUDAStream(logits.get_device_index());
becomes
options.stream_ = libtorchaudio::hip::getCurrentHIPStreamMasqueradingAsCUDA(logits.get_device_index());

Create a hip_namespace_shim that maps the hipified names back to the functions provided in src/libtorchaudio/cuda_utils.h.

This shim is included in the following CUDA source files:

  • src/libtorchaudio/rnnt/gpu/compute.cu
  • src/libtorchaudio/forced_align/gpu/compute.cu

3. Incorrect kernel launch parameters

/skishore/github/audio/src/libtorchaudio/iir_hip.hip:75:8: error: too few arguments provided to function-like macro invocation 

         75 |        hipLaunchKernelGGL(( (iir_cu_kernel<scalar_t>), dim3(blocks), dim3(threads), 0, 0, 

The parameters in THO_DISPATCH_V2 were corrected based on https://github.com/ROCm/pytorch/blob/develop/test/cpp_extensions/libtorch_agn_2_9_extension/csrc/kernel.cpp#L361:

  THO_DISPATCH_V2(m.scalar_type(), "mv_tensor_accessor_cpu",
                  AT_WRAP(([&]() {
                    auto resa = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(res.data_ptr()), res.sizes().data(), res.strides().data());
                    auto ma = Accessor_cpu<scalar_t, 2>(reinterpret_cast<scalar_t*>(m.data_ptr()), m.sizes().data(), m.strides().data());
                    auto va = Accessor_cpu<scalar_t, 1>(reinterpret_cast<scalar_t*>(v.data_ptr()), v.sizes().data(), v.strides().data());
                    mv_tensor_accessor_kernel<Accessor_cpu, scalar_t>(resa, ma, va);
                  })),
                  AT_FLOATING_TYPES);

Removing @skipIfRocm from tests

  • test/torchaudio_unittest/functional/functional_cuda_test.py
  • test/torchaudio_unittest/functional/functional_impl.py
  • test/torchaudio_unittest/functional/torchscript_consistency_impl.py
  • test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
  • test/torchaudio_unittest/functional/functional_cpu_test.py
  • test/torchaudio_unittest/transforms/autograd_test_impl.py

Test Plan

Build this branch on both an NVIDIA machine and an AMD machine, verify that it installs, and run the unit tests for the CUDA extensions:

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_rnnt  
pytest test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py -k test_rnnt 
pytest test/torchaudio_unittest/functional/autograd_cuda_test.py -k test_rnnt
pytest test/torchaudio_unittest/transforms/autograd_cuda_test.py -k test_rnnt
pytest test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py -k test_rnnt 

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/autograd_cuda_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/batch_consistency_test.py -k test_lfilter
pytest test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py -k test_lfilter

pytest test/torchaudio_unittest/functional/functional_cuda_test.py -k test_forced_align

pytest test/torchaudio_unittest/functional/autograd_cuda_test.py -k test_overdrive
pytest test/torchaudio_unittest/functional/batch_consistency_test.py -k test_overdrive
pytest test/torchaudio_unittest/functional/sox_compatibility_test.py -k test_overdrive
pytest test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py -k test_overdrive

pytest test/torchaudio_unittest/models/decoder/cuda_ctc_decoder_test.py

Test Result

Number of passed unit tests (one count per test command listed in the Test Plan):

Extension     Passed tests
RNNT loss     18, 1, 3, 3, 1
lfilter       19, 6, 2, 1
forced_align  120
overdrive     1, 2, 1, 2
cu ctc        3


@amd-sriram amd-sriram marked this pull request as ready for review February 21, 2026 18:31
@amd-sriram amd-sriram marked this pull request as draft February 23, 2026 11:27
@amd-sriram amd-sriram marked this pull request as ready for review February 23, 2026 19:26
@amd-sriram amd-sriram merged commit b9c7682 into main Feb 23, 2026
5 of 31 checks passed
@amd-sriram amd-sriram deleted the rocm_rnnt_loss_feature branch February 23, 2026 19:42
@amd-sriram amd-sriram changed the title Rocm rnnt loss feature Include ROCM support for CUDA extensions Feb 23, 2026